%load_ext autoreload
%autoreload 2
The autoreload extension is already loaded. To reload it, use: %reload_ext autoreload
from sklearn.datasets import make_moons
from sklearn.preprocessing import normalize
from matplotlib import pyplot as plt
import matplotlib.cm as cm
import numpy as np
import torch
import matplotlib
import itertools
# Side length shared by every figure; the default figure is a FIG_SIZE x FIG_SIZE square.
FIG_SIZE = 4
matplotlib.rcParams.update({'figure.figsize': (FIG_SIZE, FIG_SIZE)})
def scatter(X, title=None, ax=None, show_axes=False):
    """Draw the first two columns of ``X`` as a scatter plot.

    Args:
        X: Object indexable as ``X[:, 0]`` and ``X[:, 1]``; column 0 is used
            for the x coordinates and column 1 for the y coordinates.
        title: Optional title for the axes; nothing is set when ``None``.
        ax: Axes to draw on. Falls back to ``plt.gca()`` when ``None``.
        show_axes: When false, the tick labels of both axes are blanked.
    """
    target = plt.gca() if ax is None else ax
    xs, ys = X[:, 0], X[:, 1]
    target.scatter(xs, ys, s=10)
    if title is not None:
        target.set_title(title)
    if show_axes:
        return
    # Blank the tick labels on both axes.
    for clear_labels in (target.set_xticklabels, target.set_yticklabels):
        clear_labels([])
def plot(X, y, title=None, ax=None, show_axes=False, label=None):
    """Draw ``y`` against ``X`` as a line plot.

    Args:
        X: Values for the horizontal axis.
        y: Values for the vertical axis.
        title: Optional title for the axes; nothing is set when ``None``.
        ax: Axes to draw on. Falls back to ``plt.gca()`` when ``None``.
        show_axes: When false, the tick labels of both axes are blanked.
        label: Legend label forwarded to the line.
    """
    canvas = ax if ax is not None else plt.gca()
    canvas.plot(X, y, label=label)
    if title is not None:
        canvas.set_title(title)
    hide_tick_labels = not show_axes
    if hide_tick_labels:
        canvas.set_xticklabels([])
        canvas.set_yticklabels([])
$X \in [0, 1]^{d,N}$
N_SAMPLES = 1000

# Two-moons toy dataset; the second element returned with the points is discarded.
dataset = make_moons(n_samples=N_SAMPLES)
points, _ = dataset
X = torch.from_numpy(points).float()

# Alternative toy dataset: points on a circle.
# theta = torch.linspace(0, 2*torch.pi, 1000).unsqueeze(1)
# X = torch.cat([torch.cos(theta), torch.sin(theta)], dim=1)

# Standardize along axis 0: subtract the mean, then divide by the standard deviation.
X = X.sub(X.mean(axis=0)).div(X.std(axis=0))
scatter(X, title="Toy dataset", show_axes=True)
A DDPM (Denoising Diffusion Probabilistic Model) is described as follows:
$$ x_{t} = \sqrt{1 - \beta_t}x_{t-1} + \sqrt{\beta_t} \epsilon_{t} $$with $\epsilon_{t} \sim \mathcal{N}(0, 1)$ and $\beta_t \in [0, 1]$.
A more general and more intuitive formulation is:
$$ x_{t} = (1 - \beta_t)^px_{t-1} + {\beta_t}^p \epsilon_{t} $$with $p \in \mathbb{R}_+$.
Below, we plot the resulting coefficients using a uniformly spaced sequence of $\overline{\alpha_t}$.
# Each entry maps alpha (scalar or array) to a tuple
# (plot title, coefficient applied to x_0, coefficient applied to epsilon).
# Backslashes in the LaTeX titles are doubled: a bare "\e" is an invalid
# escape sequence and triggers a warning on recent Python versions.
interpolation_list = [
    lambda alpha: ("Linear scaling on $x_0$ (1)", alpha, (1 - alpha**2)**0.5),
    lambda alpha: ("Linear scaling on $\\epsilon$ (2)", (1 - (1 - alpha)**2)**0.5, 1 - alpha),
    lambda alpha: ("Sqrt combination [model] (3)", alpha**0.5, (1 - alpha)**0.5),
    lambda alpha: ("Linear convex combination (4)", alpha, 1 - alpha),
]
def plot_interpolations(interpolation_list, suptitle=None):
    """Plot, one panel per scheme, the x_0 and epsilon coefficients against alpha.

    Args:
        interpolation_list: Callables mapping an array ``alpha`` to a tuple
            ``(title, x_coef, eps_coef)``.
        suptitle: Optional title for the whole figure.

    Each panel shows both coefficient curves, a dashed reference line at 0.5
    and a shaded area under ``eps_coef / (x_coef + eps_coef)``. The x axis is
    inverted, so alpha runs from 1 on the left to 0 on the right.
    """
    n = len(interpolation_list)
    alpha = np.linspace(1, 0, 100)
    # squeeze=False keeps `axes` two-dimensional, so a list with a single
    # scheme still yields an iterable row of axes.
    fig, axes = plt.subplots(
        1, n, figsize=(FIG_SIZE * n, FIG_SIZE), constrained_layout=True, squeeze=False
    )
    if suptitle is not None:
        fig.suptitle(suptitle)
    for i, (interpolation_fct, ax) in enumerate(zip(interpolation_list, axes[0])):
        title, x_coef, eps_coef = interpolation_fct(alpha)
        # Legend labels are attached on the first panel only, so the shared
        # figure legend lists each curve once.
        first_panel = i == 0
        ax.set_title(title)
        ax.set_yticks([0, 0.5, 1], ["$x_0$", "$\\frac{x_0 + \\epsilon}{2}$", "$\\epsilon$"], fontsize=14)
        ax.plot(alpha, [0.5] * len(alpha), "--", color="black", alpha=0.3)
        ax.plot(alpha, x_coef, label="$x_0$" if first_panel else None)
        ax.plot(alpha, eps_coef, label="$\\epsilon$" if first_panel else None)
        ax.fill_between(alpha, eps_coef / (x_coef + eps_coef), alpha=0.1, color="black", label="$x_t$" if first_panel else None)
        ax.set_xlabel("$\\overline{\\alpha_t}$")
        ax.invert_xaxis()
    fig.legend(loc="center right", bbox_to_anchor=(1.1, 0.5), fontsize=14)
# Backslashes are doubled so that "\overline" is not parsed as an invalid escape sequence.
plot_interpolations(interpolation_list, suptitle="Evolution of $x_t$ with linear scheduling of $\\overline{\\alpha_t}$")
def compare_interpolations(interpolation_list, suptitle=None):
    """Overlay the epsilon share of every scheme on the current figure.

    Args:
        interpolation_list: Callables mapping an array ``alpha`` to a tuple
            ``(title, x_coef, eps_coef)``.
        suptitle: Optional title for the figure.

    For each scheme the curve ``eps_coef / (x_coef + eps_coef)`` is drawn for
    100 values of alpha between 0 and 1 and labelled with the scheme title.
    The x axis is inverted after plotting.
    """
    alpha = np.linspace(0, 1, 100)
    if suptitle is not None:
        plt.suptitle(suptitle)
    # Dashed reference line at the level where both coefficients are equal.
    plt.axhline(0.5, linestyle="--", color="black", alpha=0.3)
    for interpolation_fct in interpolation_list:
        title, x_coef, eps_coef = interpolation_fct(alpha)
        plt.plot(alpha, eps_coef / (x_coef + eps_coef), label=title)
    # LaTeX backslashes are doubled to avoid invalid escape sequences.
    plt.yticks([0, 0.5, 1], ["$x_0$", "$\\frac{x_0 + \\epsilon}{2}$", "$\\epsilon$"])
    plt.legend(loc="upper right", bbox_to_anchor=(1.1, 1))
    plt.xlabel("$\\overline{\\alpha_t}$")
    plt.gca().invert_xaxis()
# Backslashes are doubled so that "\overline" is not parsed as an invalid escape sequence.
compare_interpolations(interpolation_list, suptitle="Comparison of $x_t$ with different linear scheduling of $\\overline{\\alpha_t}$")
def plot_transformations(interpolation_list, suptitle=None, include_custom_alpha_bar_list=False, n_steps=10):
    """Show the noised dataset for each scheme (rows) at several alpha_bar values (columns).

    Args:
        interpolation_list: One entry per row. Each entry is a callable
            mapping ``alpha_bar`` to ``(title, x_coef, eps_coef)``, or, when
            ``include_custom_alpha_bar_list`` is true, a tuple
            ``(alpha_bar_list, callable)``.
        suptitle: Optional title for the whole figure.
        include_custom_alpha_bar_list: When true, each row uses its own
            alpha_bar schedule, subsampled to at most ``n_steps`` values.
            Otherwise every row uses ``n_steps`` values spaced evenly from 1
            to 0.
        n_steps: Number of columns in the grid.

    Every panel scatters ``x_coef * X + eps_coef * eps`` where ``X`` is the
    module-level dataset and ``eps`` is a single noise sample drawn once and
    reused for all panels.
    """
    m = len(interpolation_list)
    # squeeze=False keeps `axes` two-dimensional, so `axes[i, j]` also works
    # for a single row or a single column.
    fig, axes = plt.subplots(
        m, n_steps, figsize=(n_steps * FIG_SIZE, m * FIG_SIZE), constrained_layout=True, squeeze=False
    )
    if suptitle is not None:
        fig.suptitle(suptitle, fontsize=20)
    eps = torch.normal(0, 1, size=X.shape)
    for i, interpolation_fct in enumerate(interpolation_list):
        if include_custom_alpha_bar_list:
            assert isinstance(interpolation_fct, tuple), "interpolation_fct must be a tuple if include_custom_alpha_bar_list is True"
            alpha_bar_list, interpolation_fct = interpolation_fct
            # The stride is at least 1 (schedules shorter than n_steps) and the
            # result is capped at n_steps (lengths that are not a multiple of
            # n_steps), so the column index never leaves the grid.
            stride = max(1, len(alpha_bar_list) // n_steps)
            alpha_bar_list = alpha_bar_list[::stride][:n_steps]
        else:
            # Follow n_steps instead of a hard-coded 10 so the number of
            # values always matches the number of columns.
            alpha_bar_list = torch.linspace(1, 0, n_steps)
        for j, alpha_bar in enumerate(alpha_bar_list):
            title, x_coef, eps_coef = interpolation_fct(alpha_bar)
            if j == 0:
                axes[i, j].set_ylabel(title, fontsize=14)
            # Custom schedules differ per row, so every row gets its own titles.
            if i == 0 or include_custom_alpha_bar_list:
                axes[i, j].set_title(f"$\\overline{{\\alpha_t}}={alpha_bar:.1f}$", fontsize=18)
            scatter(x_coef*X + eps_coef*eps, ax=axes[i, j], show_axes=True)
# Backslashes are doubled so that "\overline" is not parsed as an invalid escape sequence.
plot_transformations(interpolation_list, suptitle="Transformation of $x_t$ using different linear scheduling of $\\overline{\\alpha_t}$")
Recall that we are using a linearly spaced sequence of $\overline{\alpha_t}$ and NOT of the $\alpha_t$.
This experiment allows us to understand the impact of $\overline{\alpha_t}$ on the model. Using this framework, we should be able to easily compare different scheduling strategies and their impact on the transformation of $x_t$.
# Number of diffusion steps; the schedule helpers in this file build arrays of this length.
T = 1_000
def compare_alpha_bar_evolutions(interpolation_list, suptitle=None):
    """Plot each alpha_bar schedule against the normalized diffusion step.

    Args:
        interpolation_list: Tuples ``(alpha_bar_list, interpolation_fct)``.
            ``interpolation_fct`` is called once, only to obtain the title
            used as legend label.
        suptitle: Optional title for the figure.
    """
    if suptitle is not None:
        plt.suptitle(suptitle)
    for interpolation_fct in interpolation_list:
        assert isinstance(interpolation_fct, tuple), "interpolation_fct must be a tuple"
        alpha_bar_list, interpolation_fct = interpolation_fct
        schedule = np.asarray(alpha_bar_list)
        # Only the title is needed; the coefficients are discarded.
        title, *_ = interpolation_fct(schedule)
        # The x axis is derived from the schedule's own length rather than the
        # global T, so schedules of any length can be drawn.
        plt.plot(np.linspace(0, 1, len(schedule)), schedule, label=title)
    plt.legend(loc="upper right", bbox_to_anchor=(1.1, 1))
    plt.xlabel("diffusion step ($t / T$)")
    plt.ylabel("$\\overline{\\alpha_t}$")
def get_fixed_alpha_bars(cst, n_steps=None):
    """Return the cumulative products of a constant per-step factor.

    Args:
        cst: Factor repeated at every step (callers pass ``1 - beta``).
        n_steps: Number of steps. Defaults to the module-level ``T``.

    Returns:
        1-D NumPy array of length ``n_steps`` whose entry ``k`` is the product
        of ``k + 1`` copies of ``cst``.
    """
    if n_steps is None:
        n_steps = T
    return (np.ones(n_steps) * cst).cumprod(0)
# Constant-beta schedules: each entry pairs the alpha_bar sequence obtained
# from a fixed alpha = 1 - beta with the sqrt-combination coefficients.
interpolation_list = [
    (get_fixed_alpha_bars(1 - 1e-1), lambda alpha_bar: ("$\\beta_t=10^{-1}$", alpha_bar**0.5, (1 - alpha_bar)**0.5)),
    (get_fixed_alpha_bars(1 - 1e-2), lambda alpha_bar: ("$\\beta_t=10^{-2}$", alpha_bar**0.5, (1 - alpha_bar)**0.5)),
    (get_fixed_alpha_bars(1 - 1e-2/2), lambda alpha_bar: ("$\\beta_t=\\frac{10^{-2}}{2}$", alpha_bar**0.5, (1 - alpha_bar)**0.5)),
    (get_fixed_alpha_bars(1 - 1e-2/3), lambda alpha_bar: ("$\\beta_t=\\frac{10^{-2}}{3}$", alpha_bar**0.5, (1 - alpha_bar)**0.5)),
    (get_fixed_alpha_bars(1 - 1e-2/4), lambda alpha_bar: ("$\\beta_t=\\frac{10^{-2}}{4}$", alpha_bar**0.5, (1 - alpha_bar)**0.5)),
    (get_fixed_alpha_bars(1 - 1e-3), lambda alpha_bar: ("$\\beta_t=10^{-3}$", alpha_bar**0.5, (1 - alpha_bar)**0.5)),
    (get_fixed_alpha_bars(1 - 1e-4), lambda alpha_bar: ("$\\beta_t=10^{-4}$", alpha_bar**0.5, (1 - alpha_bar)**0.5)),
]
plot_transformations(interpolation_list, include_custom_alpha_bar_list=True, suptitle="Transformation of $x_t$ with fixed $\\alpha$ scheduling")
# Backslashes are doubled so that "\overline" is not parsed as an invalid escape sequence.
compare_alpha_bar_evolutions(interpolation_list, suptitle="Evolution of $\\overline{\\alpha_t}$ with fixed $\\alpha$ scheduling")
def get_linear_alpha_bars(start, end, n_steps=None):
    """Return the cumulative products of a linearly spaced per-step factor.

    Args:
        start: Factor used at the first step (callers pass ``1 - beta``).
        end: Factor used at the last step.
        n_steps: Number of steps. Defaults to the module-level ``T``.

    Returns:
        1-D NumPy array of length ``n_steps`` holding the running product of
        ``np.linspace(start, end, n_steps)``.
    """
    if n_steps is None:
        n_steps = T
    return (np.linspace(start, end, n_steps)).cumprod(0)
# Linear-beta schedules: each entry pairs the alpha_bar sequence obtained from
# alpha going linearly from 1 - beta_start to 1 - beta_end with the
# sqrt-combination coefficients. Backslashes in the LaTeX titles are doubled
# so that "\in" and "\overline" are not parsed as invalid escape sequences.
interpolation_list = [
    (get_linear_alpha_bars(1 - 1e-4, 1 - 0.02), lambda alpha_bar: ("$\\beta_t \\in [10^{-4}, 0.02]$", alpha_bar**0.5, (1 - alpha_bar)**0.5)),  # from "Denoising Diffusion Probabilistic Models"
    (get_linear_alpha_bars(1 - 1e-5, 1 - 1e-1), lambda alpha_bar: ("$\\beta_t \\in [10^{-5}, 10^{-1}]$", alpha_bar**0.5, (1 - alpha_bar)**0.5)),
    (get_linear_alpha_bars(1 - 1e-5, 1 - 1e-2), lambda alpha_bar: ("$\\beta_t \\in [10^{-5}, 10^{-2}]$", alpha_bar**0.5, (1 - alpha_bar)**0.5)),
    (get_linear_alpha_bars(1 - 1e-5, 1 - 1e-3), lambda alpha_bar: ("$\\beta_t \\in [10^{-5}, 10^{-3}]$", alpha_bar**0.5, (1 - alpha_bar)**0.5)),
    (get_linear_alpha_bars(1 - 1e-4, 1 - 1e-1), lambda alpha_bar: ("$\\beta_t \\in [10^{-4}, 10^{-1}]$", alpha_bar**0.5, (1 - alpha_bar)**0.5)),
    (get_linear_alpha_bars(1 - 1e-4, 1 - 1e-2), lambda alpha_bar: ("$\\beta_t \\in [10^{-4}, 10^{-2}]$", alpha_bar**0.5, (1 - alpha_bar)**0.5)),
    (get_linear_alpha_bars(1 - 1e-4, 1 - 1e-3), lambda alpha_bar: ("$\\beta_t \\in [10^{-4}, 10^{-3}]$", alpha_bar**0.5, (1 - alpha_bar)**0.5)),
]
plot_transformations(interpolation_list, include_custom_alpha_bar_list=True, suptitle="Transformation of $x_t$ with linear $\\alpha$ scheduling")
compare_alpha_bar_evolutions(interpolation_list, suptitle="Evolution of $\\overline{\\alpha_t}$ with linear $\\alpha$ scheduling")
def get_cosine_alpha_bars(s, n_steps=None):
    """Return a squared-cosine alpha_bar schedule.

    Args:
        s: Offset added to the normalized step before the cosine is taken.
        n_steps: Number of steps. Defaults to the module-level ``T``.

    Returns:
        1-D NumPy array of length ``n_steps`` with
        ``cos((t / n_steps + s) / (1 + s) * pi / 2) ** 2`` evaluated at
        ``n_steps`` values of ``t`` spaced evenly from 0 to ``n_steps``.

    NOTE(review): the values are not divided by ``f(0)``, so the first entry
    is slightly below 1 whenever ``s > 0``. The cosine schedule in the
    literature is presumably normalized that way; confirm this is intended.
    """
    if n_steps is None:
        n_steps = T

    def f(t):
        return np.cos((t/n_steps + s)/(1 + s) * np.pi/2)**2

    return f(np.linspace(0, n_steps, n_steps))
# Cosine schedules for several offsets s, each paired with the
# sqrt-combination coefficients.
interpolation_list = [
    (get_cosine_alpha_bars(0.008), lambda alpha_bar: ("$s = 0.008$", alpha_bar**0.5, (1 - alpha_bar)**0.5)),
    (get_cosine_alpha_bars(1e-1), lambda alpha_bar: ("$s = 10^{-1}$", alpha_bar**0.5, (1 - alpha_bar)**0.5)),
    (get_cosine_alpha_bars(1e-2), lambda alpha_bar: ("$s = 10^{-2}$", alpha_bar**0.5, (1 - alpha_bar)**0.5)),
    (get_cosine_alpha_bars(1e-3), lambda alpha_bar: ("$s = 10^{-3}$", alpha_bar**0.5, (1 - alpha_bar)**0.5)),
    (get_cosine_alpha_bars(1e-4), lambda alpha_bar: ("$s = 10^{-4}$", alpha_bar**0.5, (1 - alpha_bar)**0.5)),
    (get_cosine_alpha_bars(1e-8), lambda alpha_bar: ("$s = 10^{-8}$", alpha_bar**0.5, (1 - alpha_bar)**0.5)),
]
plot_transformations(interpolation_list, include_custom_alpha_bar_list=True, suptitle="Transformation of $x_t$ with cosine $\\alpha$ scheduling")
# Backslashes are doubled so that "\overline" is not parsed as an invalid escape sequence.
compare_alpha_bar_evolutions(interpolation_list, suptitle="Evolution of $\\overline{\\alpha_t}$ with cosine $\\alpha$ scheduling")
# One schedule of each family, combined with the sqrt coefficients
# (sqrt(alpha_bar), sqrt(1 - alpha_bar)).
selected_interpolation_list = [
    (get_fixed_alpha_bars(1 - 1e-2/4), lambda alpha_bar: ("fixed", alpha_bar**0.5, (1 - alpha_bar)**0.5)),
    (get_linear_alpha_bars(1 - 1e-5, 1 - 1e-2), lambda alpha_bar: ("linear", alpha_bar**0.5, (1 - alpha_bar)**0.5)),
    (get_cosine_alpha_bars(1e-2), lambda alpha_bar: ("cosine", alpha_bar**0.5, (1 - alpha_bar)**0.5)),
]
# Same schedules, combined with the coefficients that scale epsilon linearly:
# (sqrt(1 - (1 - alpha_bar)^2), 1 - alpha_bar).
selected_interpolation_list_using_custom_formulation = [
    (get_fixed_alpha_bars(1 - 1e-2/4), lambda alpha_bar: ("fixed", (1 - (1 - alpha_bar)**2)**0.5, 1 - alpha_bar)),
    (get_linear_alpha_bars(1 - 1e-5, 1 - 1e-2), lambda alpha_bar: ("linear", (1 - (1 - alpha_bar)**2)**0.5, 1 - alpha_bar)),
    (get_cosine_alpha_bars(1e-2), lambda alpha_bar: ("cosine", (1 - (1 - alpha_bar)**2)**0.5, 1 - alpha_bar)),
]
plot_transformations(selected_interpolation_list, include_custom_alpha_bar_list=True, suptitle="Transformation of $x_t$ with selected $\\alpha$ scheduling (using sqrt combination)")
# This figure shows the custom formulation, so its title names it instead of
# repeating "sqrt combination".
plot_transformations(selected_interpolation_list_using_custom_formulation, include_custom_alpha_bar_list=True, suptitle="Transformation of $x_t$ with selected $\\alpha$ scheduling (using linear scaling on $\\epsilon$)")
compare_alpha_bar_evolutions(selected_interpolation_list, suptitle="Evolution of $\\overline{\\alpha_t}$ with selected $\\alpha$ scheduling")
def forward(X_0, beta_scheduler):
    """Sample the noised version of ``X_0`` at every diffusion step.

    Args:
        X_0: Tensor of clean samples.
        beta_scheduler: 1-D sequence of per-step betas; a tensor or anything
            ``torch.as_tensor`` accepts.

    Returns:
        Tuple ``(X_s, alpha_bar_list)``. ``alpha_bar_list`` is the cumulative
        product of ``1 - beta`` with a leading entry of 1 for step 0, so it
        has ``len(beta_scheduler) + 1`` entries. ``X_s`` is a list of the
        same length whose entry ``t`` is
        ``sqrt(alpha_bar_t) * X_0 + sqrt(1 - alpha_bar_t) * noise`` with fresh
        standard-normal noise per step; ``X_s[0]`` equals ``X_0``.
    """
    beta_scheduler = torch.as_tensor(beta_scheduler)
    # Prepend beta_0 = 0 so that alpha_bar_0 = 1. new_zeros matches the
    # schedule's dtype and device instead of assuming CPU float32.
    beta_scheduler = torch.cat([beta_scheduler.new_zeros(1), beta_scheduler])
    # Named n_steps rather than T to avoid shadowing the module-level constant.
    n_steps = len(beta_scheduler)
    alpha_bar_list = (1 - beta_scheduler).cumprod(dim=0)

    def qx_t(t):
        alpha_bar = alpha_bar_list[t]
        # Draw the noise on the device that holds X_0.
        noise = torch.randn(X_0.shape, device=X_0.device)
        return alpha_bar**0.5 * X_0 + (1 - alpha_bar)**0.5 * noise

    X_s = [qx_t(t) for t in range(n_steps)]
    return X_s, alpha_bar_list